{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import torch\n",
    "import adversary.cw as cw\n",
    "from adversary.jsma import SaliencyMapMethod\n",
    "from adversary.fgsm import Attack\n",
    "import torchvision\n",
    "import torch.nn.functional as F\n",
    "import torch.utils.data as Data\n",
    "from models.mnist_model import MnistModel, MLP\n",
    "from torchvision import transforms\n",
    "\n",
    "%reload_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'cuda:1'"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Checkpoint file for the undercover MNIST model.\n",
    "MNIST_UNDERCOVER_CKPT = './checkpoint/mnist_undercover.pth'\n",
    "\n",
    "# Prefer the second GPU (index 1) when it exists; otherwise fall back to the\n",
    "# first GPU, and to the CPU when CUDA is unavailable, so the notebook also\n",
    "# runs on single-GPU machines where 'cuda:1' is not a valid device.\n",
    "if not torch.cuda.is_available():\n",
    "    device = 'cpu'\n",
    "elif torch.cuda.device_count() > 1:\n",
    "    device = 'cuda:1'\n",
    "else:\n",
    "    device = 'cuda:0'\n",
    "device"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<All keys matched successfully>"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Images are only converted to tensors; no further normalisation is applied.\n",
    "transform_test = transforms.Compose([transforms.ToTensor()])\n",
    "\n",
    "# Detector: the MLP optimised by the training loop further down.\n",
    "mlp = MLP().to(device)\n",
    "criterion = torch.nn.CrossEntropyLoss()\n",
    "optimizer = torch.optim.SGD(mlp.parameters(), lr=1e-3, momentum=0.9, weight_decay=5e-4)\n",
    "\n",
    "# Classifier under attack, restored from the checkpoint. It is never trained\n",
    "# in this notebook, so put it in eval mode: mode-dependent layers such as\n",
    "# dropout / batch norm (if MnistModel has any) then behave deterministically\n",
    "# for both the attacks and the feature extraction.\n",
    "undercoverNet = MnistModel().to(device).eval()\n",
    "checkpoint = torch.load(MNIST_UNDERCOVER_CKPT, map_location=torch.device(device))\n",
    "# Kept as the last expression so the cell displays the key-matching report.\n",
    "undercoverNet.load_state_dict(checkpoint['net'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "# MNIST train / test splits (downloaded into ./data on first use); both use transform_test.\n",
    "trainset = torchvision.datasets.MNIST(\n",
    "    root='./data', train=True, download=True, transform=transform_test,\n",
    ")\n",
    "testset = torchvision.datasets.MNIST(\n",
    "    root='./data', train=False, download=True, transform=transform_test,\n",
    ")\n",
    "\n",
    "# Only the training split is shuffled.\n",
    "trainloader = Data.DataLoader(trainset, batch_size=512, shuffle=True, num_workers=4)\n",
    "testloader = Data.DataLoader(testset, batch_size=512, shuffle=False, num_workers=4)\n",
    "\n",
    "trainiter = iter(trainloader)\n",
    "testiter = iter(testloader)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Take BIM attack as an example"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Attacker built from undercoverNet and F.cross_entropy; it provides the\n",
    "# fgsm / i_fgsm calls used in the cells below.\n",
    "undercover_gradient_attacker = Attack(\n",
    "    undercoverNet,\n",
    "    F.cross_entropy,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Build the detector datasets from BIM (iterative FGSM) adversarial samples.\n",
    "BIM_EPS = 0.3        # attack strength (eps) passed to i_fgsm\n",
    "BIM_ALPHA = 1 / 255  # step size (alpha) passed to i_fgsm\n",
    "\n",
    "\n",
    "def collect_bim_pairs(loader, eps=BIM_EPS, alpha=BIM_ALPHA):\n",
    "    \"\"\"Attack every batch of `loader` with i_fgsm and return detector data.\n",
    "\n",
    "    A sample is kept only when undercoverNet classifies the clean input\n",
    "    correctly and the attacked input incorrectly, so every clean sample has\n",
    "    a successful adversarial counterpart.\n",
    "\n",
    "    Args:\n",
    "        loader: iterable yielding (images, digit labels) batches.\n",
    "        eps: eps argument forwarded to i_fgsm.\n",
    "        alpha: alpha argument forwarded to i_fgsm.\n",
    "\n",
    "    Returns:\n",
    "        (inputs, labels): CPU tensors with the kept clean samples followed by\n",
    "        their adversarial versions; labels are 0 for clean, 1 for adversarial.\n",
    "    \"\"\"\n",
    "    # Iteration count min(eps + 4, 1.25 * eps) with eps scaled to 0-255 pixel\n",
    "    # units, truncated to an int (80 steps for eps = 0.3).\n",
    "    n_iter = int(min(eps * 255 + 4, 1.25 * eps * 255))\n",
    "\n",
    "    clean_batches, adv_batches = [], []\n",
    "    for x, y in loader:\n",
    "        x, y = x.to(device), y.to(device)\n",
    "        # Predictions are only compared with labels, so no autograd graph is needed.\n",
    "        with torch.no_grad():\n",
    "            y_pred = undercoverNet(x).argmax(dim=1)\n",
    "        x_adv = undercover_gradient_attacker.i_fgsm(x, y, eps=eps, alpha=alpha, iteration=n_iter)\n",
    "        with torch.no_grad():\n",
    "            y_pred_adv = undercoverNet(x_adv).argmax(dim=1)\n",
    "        selected = (y == y_pred) & (y != y_pred_adv)\n",
    "        clean_batches.append(x[selected].detach().cpu())\n",
    "        adv_batches.append(x_adv[selected].detach().cpu())\n",
    "\n",
    "    clean_x = torch.cat(clean_batches, dim=0)\n",
    "    adv_x = torch.cat(adv_batches, dim=0)\n",
    "    inputs = torch.cat([clean_x, adv_x], dim=0)\n",
    "    labels = torch.cat([torch.zeros(clean_x.shape[0]).long(),\n",
    "                        torch.ones(adv_x.shape[0]).long()], dim=0)\n",
    "    return inputs, labels\n",
    "\n",
    "\n",
    "# --------------------train---------------------\n",
    "train_inputs, train_labels = collect_bim_pairs(trainloader)\n",
    "dba_trainloader = Data.DataLoader(Data.TensorDataset(train_inputs, train_labels),\n",
    "                                  batch_size=512, shuffle=True, num_workers=4)\n",
    "dba_trainiter = iter(dba_trainloader)\n",
    "\n",
    "# --------------------test----------------------\n",
    "test_inputs, test_labels = collect_bim_pairs(testloader)\n",
    "dba_testloader = Data.DataLoader(Data.TensorDataset(test_inputs, test_labels),\n",
    "                                 batch_size=1024, shuffle=True, num_workers=4)\n",
    "dba_testiter = iter(dba_testloader)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "def dba_features(x, y):\n",
    "    \"\"\"Build the detector input for one batch: cat([V1, V2, V1 - V2, V1 * V2]).\n",
    "\n",
    "    V1 is the second output of undercoverNet(x, dba=True). V2 is the same\n",
    "    output for undercover_gradient_attacker.fgsm(x, y, False, 1/255), i.e. the\n",
    "    batch after one more attack step on undercoverNet.\n",
    "\n",
    "    Args:\n",
    "        x: input batch, already on `device`.\n",
    "        y: label batch handed to the attacker, already on `device`.\n",
    "\n",
    "    Returns:\n",
    "        Tensor with the four feature groups concatenated along the last dim.\n",
    "    \"\"\"\n",
    "    # Only mlp is optimised, so undercoverNet activations need no autograd graph.\n",
    "    with torch.no_grad():\n",
    "        _, V1 = undercoverNet(x, dba=True)\n",
    "    # NOTE(review): callers pass the detector label (0 = clean, 1 = adversarial)\n",
    "    # as y, not the digit class - confirm this is what Attack.fgsm expects.\n",
    "    undercover_adv = undercover_gradient_attacker.fgsm(x, y, False, 1/255)\n",
    "    with torch.no_grad():\n",
    "        _, V2 = undercoverNet(undercover_adv, dba=True)\n",
    "    return torch.cat([V1, V2, V1 - V2, V1 * V2], dim=-1)\n",
    "\n",
    "\n",
    "# train the mlp\n",
    "epochs = 10\n",
    "mlp.train()  # make the training mode explicit (the test cell switches to eval)\n",
    "for epoch in range(epochs):\n",
    "    for x, y in dba_trainloader:\n",
    "        optimizer.zero_grad()\n",
    "        x, y = x.to(device), y.to(device)\n",
    "        y_pred = mlp(dba_features(x, y))\n",
    "        loss = criterion(y_pred, y)\n",
    "        loss.backward()\n",
    "        optimizer.step()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0.997010144153764\n"
     ]
    }
   ],
   "source": [
    "# test: detector accuracy on the held-out clean / adversarial pairs\n",
    "mlp.eval()  # inference mode for the detector\n",
    "total, correct = 0, 0\n",
    "for x, y in dba_testloader:\n",
    "    x, y = x.to(device), y.to(device)\n",
    "    # dba_features runs an attack internally and therefore needs gradients;\n",
    "    # only the detector forward pass is wrapped in no_grad.\n",
    "    V = dba_features(x, y)\n",
    "    with torch.no_grad():\n",
    "        y_pred = mlp(V).argmax(dim=1)\n",
    "\n",
    "    total += y.size(0)\n",
    "    correct += y_pred.eq(y).sum().item()\n",
    "print(correct / total)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
